from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional

from executorch.exir import EdgeProgramManager

from torch.export import ExportedProgram


class StageType(Enum):
    """
    Identifiers for the stages of the lowering pipeline.

    Every member maps to a fixed integer value, declared here in ascending
    value order. The values are stable identifiers; do not renumber them.
    """

    QUANTIZE = 0
    EXPORT = 1
    RUN_PASSES = 2
    TO_EDGE = 3
    TO_EDGE_TRANSFORM_AND_LOWER = 4
    PARTITION = 5
    TO_EXECUTORCH = 6
    SERIALIZE = 7
    INITIAL_MODEL = 8


class Stage(ABC):
    """
    Interface for a Stage in the PT2.0 lowering pipeline.

    Concrete stages implement ``stage_type``, ``run``, ``artifact`` and
    ``graph_module``. The remaining methods are shared helpers built on top of
    ``artifact``: executing it on sample inputs and dumping a printable form
    of it for debugging.
    """

    @abstractmethod
    def stage_type(self) -> StageType:
        """
        Returns the type of the stage.
        """
        pass

    @abstractmethod
    def run(self, artifact, inputs):
        """
        Executes this stage, generates the 'artifact', for later stages.
        """
        pass

    @property
    @abstractmethod
    def artifact(self):
        """
        Returns the artifact generated by this stage. To be used by the next stage in the pipeline.
        """
        pass

    @property
    @abstractmethod
    def graph_module(self):
        """
        Return the artifact's graph module for this stage
        """
        pass

    def run_artifact(self, inputs):
        """
        Returns the output of calling the artifact generated by this stage with inputs.

        Args:
            inputs: Sequence of positional arguments, unpacked into the call.

        The artifact is either an ``ExportedProgram`` or an object exposing
        ``exported_program()`` (for example an ``EdgeProgramManager``). In both
        cases the program is executed through ``ExportedProgram.module()``,
        because recent PyTorch releases reject calling an ``ExportedProgram``
        directly.
        """
        program = self.artifact
        if not isinstance(program, ExportedProgram):
            # Program managers wrap the ExportedProgram; unwrap it first.
            program = program.exported_program()
        return program.module()(*inputs)

    # Debug Tools for stages
    def artifact_str(self):
        """
        Return string printable artifact for this stage.

        An ``EdgeProgramManager`` is replaced by its ``exported_program()`` so
        the printout shows the program itself; any other artifact is returned
        unchanged.
        """
        if isinstance(self.artifact, EdgeProgramManager):
            return self.artifact.exported_program()
        return self.artifact

    def stage_banner(self):
        """
        Returns banner string for this stage: the concrete class name framed by
        36 '#' characters on each side, terminated by a newline.
        """
        return "#" * 36 + " " + self.__class__.__name__ + " " + "#" * 36 + "\n"

    def dump_artifact(self, path_to_dump: Optional[str] = None):
        """
        Dumps the stage banner followed by the string printable artifact.

        Args:
            path_to_dump: File to append the dump to. If None or empty, the
                dump is printed to stdout instead.

        File and terminal output are identical, and every dump ends with a
        newline so consecutive dumps appended to the same file stay separated.
        """
        # Build the full text before touching the file so that a failure in
        # artifact_str() cannot leave a dangling banner behind.
        # The banner already ends in "\n"; the extra "\n\n" leaves a blank line.
        dump = self.stage_banner() + "\n\n" + str(self.artifact_str()) + "\n"
        if path_to_dump:
            with open(path_to_dump, "a", encoding="utf-8") as fp:
                fp.write(dump)
        else:
            print(dump, end="")
